{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "f:\\pythonprojects\\lenv\\test\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "import os\n",
    "\n",
    "import keras\n",
    "import numpy as np\n",
    "from keras.callbacks import LambdaCallback\n",
    "from keras.models import Input, Model, load_model\n",
    "from keras.layers import LSTM, Dropout, Dense\n",
    "from keras.optimizers import Adam\n",
    "\n",
    "from data_utils import *\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PoetryModel(object):\n",
    "    '''Character-level LSTM model that generates Chinese poetry.\n",
    "\n",
    "    Building an instance preprocesses the corpus via ``preprocess_file`` and\n",
    "    then either loads a saved model or trains a new one.  The ``config``\n",
    "    attributes read by this class are ``weight_file``, ``max_len``,\n",
    "    ``learning_rate``, ``batch_size`` and ``poetry_file``.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, config, load_existing=False):\n",
    "        '''Preprocess the corpus, then load or train the model.\n",
    "\n",
    "        Args:\n",
    "            config: configuration object (see the class docstring).\n",
    "            load_existing: if True and ``config.weight_file`` exists, load\n",
    "                that saved model instead of training.  The default (False)\n",
    "                always trains, which is what this constructor did before\n",
    "                the flag existed.\n",
    "        '''\n",
    "        self.model = None\n",
    "        self.do_train = True\n",
    "        # Becomes True only after a saved model has been loaded from disk.\n",
    "        self.loaded_model = False\n",
    "        self.config = config\n",
    "\n",
    "        # Corpus preprocessing.\n",
    "        self.word2numF, self.num2word, self.words, self.files_content = preprocess_file(self.config)\n",
    "\n",
    "        # The corpus text is split on ']'; each piece is treated as one poem.\n",
    "        self.poems = self.files_content.split(']')\n",
    "        self.poems_num = len(self.poems)\n",
    "\n",
    "        # Load the saved model when asked to and it exists; otherwise train.\n",
    "        if load_existing and os.path.exists(self.config.weight_file):\n",
    "            self.model = load_model(self.config.weight_file)\n",
    "            self.loaded_model = True\n",
    "        else:\n",
    "            self.train()\n",
    "\n",
    "    def build_model(self):\n",
    "        '''Build and compile the two-layer LSTM network into ``self.model``.'''\n",
    "        print('building model')\n",
    "\n",
    "        # Input: a window of max_len one-hot encoded characters.\n",
    "        input_tensor = Input(shape=(self.config.max_len, len(self.words)))\n",
    "        lstm = LSTM(512, return_sequences=True)(input_tensor)\n",
    "        dropout = Dropout(0.6)(lstm)\n",
    "        lstm = LSTM(256)(dropout)\n",
    "        dropout = Dropout(0.6)(lstm)\n",
    "        # Softmax over the vocabulary: one probability per candidate next character.\n",
    "        dense = Dense(len(self.words), activation='softmax')(dropout)\n",
    "        self.model = Model(inputs=input_tensor, outputs=dense)\n",
    "        optimizer = Adam(lr=self.config.learning_rate)\n",
    "        self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])\n",
    "\n",
    "    def sample(self, preds, temperature=1.0):\n",
    "        '''Draw one index from ``preds`` after temperature re-weighting.\n",
    "\n",
    "        Every probability is raised to the power ``1 / temperature`` and the\n",
    "        result is re-normalised, so:\n",
    "          * temperature == 1.0 leaves the distribution unchanged;\n",
    "          * temperature < 1.0 sharpens it (more conservative output);\n",
    "          * temperature > 1.0 flattens it (more varied output).\n",
    "\n",
    "        Args:\n",
    "            preds: sequence of non-negative probabilities.\n",
    "            temperature: positive re-weighting factor.\n",
    "\n",
    "        Returns:\n",
    "            The sampled index as an int.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``temperature`` is not positive.\n",
    "        '''\n",
    "        if temperature <= 0:\n",
    "            raise ValueError('temperature must be positive')\n",
    "        preds = np.asarray(preds).astype('float64')\n",
    "        exp_preds = np.power(preds, 1. / temperature)\n",
    "        preds = exp_preds / np.sum(exp_preds)\n",
    "        pro = np.random.choice(range(len(preds)), 1, p=preds)\n",
    "        return int(pro.squeeze())\n",
    "\n",
    "    def _next_char(self, window, temperature=1.0):\n",
    "        '''One-hot encode ``window``, run the model and sample the next character.'''\n",
    "        x_pred = np.zeros((1, self.config.max_len, len(self.words)))\n",
    "        for t, char in enumerate(window):\n",
    "            x_pred[0, t, self.word2numF(char)] = 1.\n",
    "        preds = self.model.predict(x_pred, verbose=0)[0]\n",
    "        return self.num2word[self.sample(preds, temperature)]\n",
    "\n",
    "    def generate_sample_result(self, epoch, logs):\n",
    "        '''Epoch-end callback: every fourth epoch, print and log sample text.\n",
    "\n",
    "        For each diversity in 0.7, 1.0 and 1.3 the first ``max_len``\n",
    "        characters of a random poem seed the model and 18 more characters\n",
    "        are generated.  Results are printed and appended to ``out/out.txt``.\n",
    "\n",
    "        Args:\n",
    "            epoch: epoch index passed by the Keras callback.\n",
    "            logs: metrics passed by the Keras callback (unused).\n",
    "        '''\n",
    "        if epoch % 4 != 0:\n",
    "            return\n",
    "\n",
    "        # Make sure the log directory exists before appending to it.\n",
    "        os.makedirs('out', exist_ok=True)\n",
    "        with open('out/out.txt', 'a', encoding='utf-8') as f:\n",
    "            f.write('==================Epoch {}=====================\\n'.format(epoch))\n",
    "\n",
    "        print(\"\\n==================Epoch {}=====================\".format(epoch))\n",
    "        for diversity in [0.7, 1.0, 1.3]:\n",
    "            print(\"------------Diversity {}--------------\".format(diversity))\n",
    "            # randrange excludes its upper bound; randint(0, n) could return n\n",
    "            # and index one past the end of self.poems.\n",
    "            index = random.randrange(self.poems_num)\n",
    "            sentence = self.poems[index][: self.config.max_len]\n",
    "            generated = str(sentence)\n",
    "\n",
    "            for _ in range(18):\n",
    "                next_char = self._next_char(sentence, diversity)\n",
    "                generated += next_char\n",
    "                # Slide the window one character forward.\n",
    "                sentence = sentence[1:] + next_char\n",
    "            print(generated)\n",
    "\n",
    "            with open('out/out.txt', 'a', encoding='utf-8') as f:\n",
    "                f.write(generated + '\\n')\n",
    "\n",
    "    def predict(self, text):\n",
    "        '''Generate text with one segment per character of ``text``.\n",
    "\n",
    "        ``text`` is padded to four characters with random non-punctuation\n",
    "        vocabulary characters.  Starting from the tail of a random line of\n",
    "        ``config.poetry_file``, each character of ``text`` is pushed into the\n",
    "        sliding window, five more characters are generated, and the window\n",
    "        is appended to the result.\n",
    "\n",
    "        Args:\n",
    "            text: seed characters; may be None or empty.\n",
    "\n",
    "        Returns:\n",
    "            The generated string, or None when no model is available.\n",
    "        '''\n",
    "        if self.model is None:\n",
    "            return\n",
    "        with open(self.config.poetry_file, 'r', encoding='UTF-8') as f:\n",
    "            file_list = f.readlines()\n",
    "        random_line = random.choice(file_list)\n",
    "\n",
    "        # Pad short input with random vocabulary characters, never punctuation.\n",
    "        text = text or ''\n",
    "        if len(text) < 4:\n",
    "            punctuation = (',', '。', '，')\n",
    "            candidates = [self.num2word.get(i) for i in range(len(self.words))]\n",
    "            candidates = [w for w in candidates if w is not None and w not in punctuation]\n",
    "            while candidates and len(text) < 4:\n",
    "                text += random.choice(candidates)\n",
    "        seed = random_line.strip()[-(self.config.max_len):]\n",
    "\n",
    "        res = ''\n",
    "        for c in text:\n",
    "            seed = seed[1:] + c\n",
    "            print('seed before = ', seed)\n",
    "            for _ in range(5):\n",
    "                seed = seed[1:] + self._next_char(seed, 1.0)\n",
    "            print('seed after= ', seed)\n",
    "            res += seed\n",
    "        return res\n",
    "\n",
    "    def predict2(self, text):\n",
    "        '''Continue a poem from its first line.\n",
    "\n",
    "        The last ``max_len`` characters of ``text`` seed the model and 50\n",
    "        generated characters are appended to them.\n",
    "\n",
    "        Args:\n",
    "            text: opening line, at least ``max_len`` characters long.\n",
    "\n",
    "        Returns:\n",
    "            Seed plus generated characters, or None when no model is\n",
    "            available or ``text`` is shorter than ``max_len``.\n",
    "        '''\n",
    "        if self.model is None:\n",
    "            return\n",
    "        max_len = self.config.max_len\n",
    "        if len(text) < max_len:\n",
    "            print('length should not be less than ', max_len)\n",
    "            return\n",
    "\n",
    "        line = text[-max_len:]\n",
    "        print('the first line:', line)\n",
    "        res = str(line)\n",
    "        for _ in range(50):\n",
    "            next_char = self._next_char(line, 1.0)\n",
    "            line = line[1:] + next_char\n",
    "            res += next_char\n",
    "        return res\n",
    "\n",
    "    def data_generator(self):\n",
    "        '''Yield one-hot ``(x, y)`` training pairs indefinitely.\n",
    "\n",
    "        ``x`` encodes a window of ``max_len`` consecutive corpus characters\n",
    "        with shape ``(1, max_len, vocab)``; ``y`` encodes the character that\n",
    "        follows it with shape ``(1, vocab)``.  Windows that touch the ']'\n",
    "        separator are skipped.  At the end of the corpus the position wraps\n",
    "        back to the start.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the corpus is not longer than ``max_len``.\n",
    "        '''\n",
    "        max_len = self.config.max_len\n",
    "        corpus = self.files_content\n",
    "        vocab_size = len(self.words)\n",
    "        if len(corpus) <= max_len:\n",
    "            raise ValueError('corpus must be longer than max_len')\n",
    "\n",
    "        i = 0\n",
    "        while 1:\n",
    "            # Wrap around instead of indexing past the end of the corpus.\n",
    "            if i + max_len >= len(corpus):\n",
    "                i = 0\n",
    "            x = corpus[i: i + max_len]\n",
    "            y = corpus[i + max_len]\n",
    "\n",
    "            if ']' in x or ']' in y:\n",
    "                i += 1\n",
    "                continue\n",
    "\n",
    "            # Plain bool: the np.bool alias was deprecated and removed in newer NumPy.\n",
    "            y_vec = np.zeros(shape=(1, vocab_size), dtype=bool)\n",
    "            y_vec[0, self.word2numF(y)] = 1.0\n",
    "\n",
    "            x_vec = np.zeros(shape=(1, max_len, vocab_size), dtype=bool)\n",
    "            for t, char in enumerate(x):\n",
    "                x_vec[0, t, self.word2numF(char)] = 1.0\n",
    "\n",
    "            yield x_vec, y_vec\n",
    "            i += 1\n",
    "\n",
    "    def train(self):\n",
    "        '''Train the model, building it first when none exists.\n",
    "\n",
    "        The model is checkpointed to ``config.weight_file`` after every\n",
    "        epoch and ``generate_sample_result`` runs as an epoch-end callback.\n",
    "        '''\n",
    "        print('training')\n",
    "        # Epoch count: corpus length minus (max_len + 1) characters per poem,\n",
    "        # divided by the batch size and then by 1.5.\n",
    "        number_of_epoch = len(self.files_content)-(self.config.max_len + 1)*self.poems_num\n",
    "        number_of_epoch /= self.config.batch_size\n",
    "        number_of_epoch = int(number_of_epoch / 1.5)\n",
    "        print('epoches = ',number_of_epoch)\n",
    "        print('poems_num = ',self.poems_num)\n",
    "        print('len(self.files_content) = ',len(self.files_content))\n",
    "\n",
    "        if not self.model:\n",
    "            self.build_model()\n",
    "\n",
    "        # Each generator item holds a single sample, so steps_per_epoch equal\n",
    "        # to batch_size means one Keras epoch consumes batch_size samples.\n",
    "        self.model.fit_generator(\n",
    "            generator=self.data_generator(),\n",
    "            verbose=True,\n",
    "            steps_per_epoch=self.config.batch_size,\n",
    "            epochs=number_of_epoch,\n",
    "            callbacks=[\n",
    "                keras.callbacks.ModelCheckpoint(self.config.weight_file, save_weights_only=False),\n",
    "                LambdaCallback(on_epoch_end=self.generate_sample_result)\n",
    "            ]\n",
    "        )\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training\n",
      "epoches =  34858\n",
      "poems_num =  24027\n",
      "len(self.files_content) =  1841397\n",
      "building model\n",
      "Epoch 1/34858\n",
      "32/32 [==============================] - 73s 2s/step - loss: 8.6177 - acc: 0.0312\n",
      "\n",
      "==================Epoch 0=====================\n",
      "------------Diversity 0.7--------------\n",
      "一醉六十日，累闉强岳夷涂迁号倪蜂溟瞳供蜉嘱葳处务\n",
      "------------Diversity 1.0--------------\n",
      "茅屋往来久，底茎皮髻宋释祜郑蒯邮岁磈摸茯奠孔黮试\n",
      "------------Diversity 1.3--------------\n",
      "大位天下宝，清剿槟冒愤世洽肮猊：怕占棨淫继渐雎饬\n",
      "Epoch 2/34858\n",
      "32/32 [==============================] - 32s 985ms/step - loss: 8.1601 - acc: 0.0625\n",
      "Epoch 3/34858\n",
      "32/32 [==============================] - 32s 996ms/step - loss: 7.4346 - acc: 0.0938\n",
      "Epoch 4/34858\n",
      "32/32 [==============================] - 32s 985ms/step - loss: 8.0106 - acc: 0.0938\n",
      "Epoch 5/34858\n",
      "32/32 [==============================] - 32s 985ms/step - loss: 7.8604 - acc: 0.0625\n",
      "\n",
      "==================Epoch 4=====================\n",
      "------------Diversity 0.7--------------\n",
      "禅宫新歇雨，，缥，。。媪。，。。。孕，，。初朔。\n",
      "------------Diversity 1.0--------------\n",
      "上林新柳变，仁瑕齐忻(递还解挛权风牧隗园铸浅芥者\n",
      "------------Diversity 1.3--------------\n",
      "远客坐长夜，冏溺低酎碧澧聚陪且勒达哂渤前涬葬隔娭\n",
      "Epoch 6/34858\n",
      "32/32 [==============================] - 32s 994ms/step - loss: 7.6186 - acc: 0.0938\n",
      "Epoch 7/34858\n",
      "32/32 [==============================] - 32s 1000ms/step - loss: 7.5555 - acc: 0.0625\n",
      "Epoch 8/34858\n",
      "32/32 [==============================] - 32s 990ms/step - loss: 7.4748 - acc: 0.0000e+00\n",
      "Epoch 9/34858\n",
      "32/32 [==============================] - 31s 984ms/step - loss: 7.4204 - acc: 0.1250\n",
      "\n",
      "==================Epoch 8=====================\n",
      "------------Diversity 0.7--------------\n",
      "薄命头欲白，懿埃婵锤阕，琚騂花腊鸟，茗瓠嗥日絺。\n",
      "------------Diversity 1.0--------------\n",
      "晨趋紫禁中，批觏蕃嵯他乏晋跃查麈愧渍馹滂芹恕橡汾\n",
      "------------Diversity 1.3--------------\n",
      "暴雨逐惊雷，睡輶牡鞋酩髓痕缵碾京鶂舼濠裀嵬嗜躅遮\n",
      "Epoch 10/34858\n",
      "32/32 [==============================] - 32s 991ms/step - loss: 7.5820 - acc: 0.0312\n",
      "Epoch 11/34858\n",
      "32/32 [==============================] - 31s 979ms/step - loss: 7.2330 - acc: 0.0000e+00\n",
      "Epoch 12/34858\n",
      "32/32 [==============================] - 32s 988ms/step - loss: 7.0780 - acc: 0.1250\n",
      "Epoch 13/34858\n",
      "32/32 [==============================] - 31s 979ms/step - loss: 7.3093 - acc: 0.0938\n",
      "\n",
      "==================Epoch 12=====================\n",
      "------------Diversity 0.7--------------\n",
      "分险架长澜，日，鱼鸟。，。，百臼，，，。日，。。\n",
      "------------Diversity 1.0--------------\n",
      "忧勤承圣绪，袭穹来劬。，蒯曙攫，。。濞崟，。。。\n",
      "------------Diversity 1.3--------------\n",
      "桓公名已古，龃邓魈旃别，遄破雪螮月，产新谷通梦脆\n",
      "Epoch 14/34858\n",
      "32/32 [==============================] - 31s 980ms/step - loss: 7.8563 - acc: 0.0625\n",
      "Epoch 15/34858\n",
      "32/32 [==============================] - 32s 985ms/step - loss: 7.1950 - acc: 0.0938\n",
      "Epoch 16/34858\n",
      "32/32 [==============================] - 31s 984ms/step - loss: 7.2782 - acc: 0.0625\n",
      "Epoch 17/34858\n",
      "32/32 [==============================] - 31s 978ms/step - loss: 7.2457 - acc: 0.1250\n",
      "\n",
      "==================Epoch 16=====================\n",
      "------------Diversity 0.7--------------\n",
      "佳期曾不远，歌知读筦网，理狸观鹤蔟，移诈灞剥悉，\n",
      "------------Diversity 1.0--------------\n",
      "仲月景气佳，保缥畦怵腑。懵昆垦也锥。额与俦闺乐。\n",
      "------------Diversity 1.3--------------\n",
      "孔坐洽良俦，械亥超吓谋，趾晼写轸嶷，妻囚汲匄械逻\n",
      "Epoch 18/34858\n",
      "32/32 [==============================] - 31s 975ms/step - loss: 7.1581 - acc: 0.1875\n",
      "Epoch 19/34858\n",
      "32/32 [==============================] - 31s 981ms/step - loss: 7.2665 - acc: 0.0938\n",
      "Epoch 20/34858\n",
      "32/32 [==============================] - 31s 977ms/step - loss: 7.1639 - acc: 0.1250\n",
      "Epoch 21/34858\n",
      "32/32 [==============================] - 31s 975ms/step - loss: 6.7273 - acc: 0.0938\n",
      "\n",
      "==================Epoch 20=====================\n",
      "------------Diversity 0.7--------------\n",
      "李洞僻相似，飞虚罔蝶接，，沟初花，，。风飞飞，，\n",
      "------------Diversity 1.0--------------\n",
      "风沙悲久戍，蜉，管飞验，愉，园襜色，溧鸟千旸觊，\n",
      "------------Diversity 1.3--------------\n",
      "郁郁复郁郁，泠池胶原龙贿侈披埒古蟋魑淬消惬废醲散\n",
      "Epoch 22/34858\n",
      "32/32 [==============================] - 31s 978ms/step - loss: 7.0865 - acc: 0.1250\n",
      "Epoch 23/34858\n",
      "32/32 [==============================] - 31s 977ms/step - loss: 6.9763 - acc: 0.0625\n",
      "Epoch 24/34858\n",
      "32/32 [==============================] - 31s 976ms/step - loss: 6.7768 - acc: 0.0938\n",
      "Epoch 25/34858\n",
      "32/32 [==============================] - 31s 974ms/step - loss: 7.2370 - acc: 0.0938\n",
      "\n",
      "==================Epoch 24=====================\n",
      "------------Diversity 0.7--------------\n",
      "试入山亭望，翠牖稼谅远，翠脏熏肿芳。月箭色玉。，\n",
      "------------Diversity 1.0--------------\n",
      "三十事诸侯，弯赃横落翠。琏糇阎驾杂。筦鱼皇千瑙。\n",
      "------------Diversity 1.3--------------\n",
      "前不见古人，圄苾蹑骢灉祷偃淤樽幼洽瘢氛挥荤贰鹤损\n",
      "Epoch 26/34858\n",
      "32/32 [==============================] - 32s 985ms/step - loss: 7.6505 - acc: 0.0938\n",
      "Epoch 27/34858\n",
      "32/32 [==============================] - 31s 977ms/step - loss: 7.1338 - acc: 0.1250\n",
      "Epoch 28/34858\n",
      "32/32 [==============================] - 31s 980ms/step - loss: 7.4553 - acc: 0.0625\n",
      "Epoch 29/34858\n",
      "32/32 [==============================] - 31s 977ms/step - loss: 6.9534 - acc: 0.0938\n",
      "\n",
      "==================Epoch 28=====================\n",
      "------------Diversity 0.7--------------\n",
      "回首望知音，芳去维通去，曦明桂林胪。疏雕牵芳出。\n",
      "------------Diversity 1.0--------------\n",
      "不资冬日秀，豳四饧渎稚。赏叉赏赦邹。衿姝鹴林鸣。\n",
      "------------Diversity 1.3--------------\n",
      "昔有嵇氏子，每赏坎星耄。隼光称椰徇。菅鱼欤恒减。\n",
      "Epoch 30/34858\n",
      "32/32 [==============================] - 31s 980ms/step - loss: 7.1779 - acc: 0.1562\n",
      "Epoch 31/34858\n",
      "32/32 [==============================] - 31s 979ms/step - loss: 8.0979 - acc: 0.0625\n",
      "Epoch 32/34858\n",
      "32/32 [==============================] - 31s 976ms/step - loss: 7.5519 - acc: 0.0625\n",
      "Epoch 33/34858\n",
      "32/32 [==============================] - 31s 984ms/step - loss: 7.0965 - acc: 0.1250\n",
      "\n",
      "==================Epoch 32=====================\n",
      "------------Diversity 0.7--------------\n",
      "漕水东去远，赏出岿节烛落落高嘲铓犹冥云移等翁噞云\n",
      "------------Diversity 1.0--------------\n",
      "独绕虚斋径，驹裁翘咨魉满村祐古殳柯接烟阙泚绮鲭檗\n",
      "------------Diversity 1.3--------------\n",
      "殷辛帝天下，每蒿馥竺蠓资洗马壈参草先悸塿伞踟雷舂\n",
      "Epoch 34/34858\n",
      "32/32 [==============================] - 31s 980ms/step - loss: 7.0554 - acc: 0.1250\n",
      "Epoch 35/34858\n",
      "32/32 [==============================] - 31s 974ms/step - loss: 7.2609 - acc: 0.1562\n",
      "Epoch 36/34858\n",
      "32/32 [==============================] - 31s 980ms/step - loss: 7.3603 - acc: 0.1875\n",
      "Epoch 37/34858\n",
      "32/32 [==============================] - 31s 979ms/step - loss: 7.2900 - acc: 0.1875\n",
      "\n",
      "==================Epoch 36=====================\n",
      "------------Diversity 0.7--------------\n",
      "浮光上东洛，悠衣祉等牌。枕辩河色惊，晋月倍急泠。\n",
      "------------Diversity 1.0--------------\n",
      "春风泛摇草，碧莓懒仇睡日类酣玩妻紫阳噀楼嘒业月藏\n",
      "------------Diversity 1.3--------------\n",
      "遭乱发尽白，谅川篲兟嵇。瞪脍鳅濯冤，鼯蔬成影陨梅\n",
      "Epoch 38/34858\n",
      "32/32 [==============================] - 31s 975ms/step - loss: 7.4245 - acc: 0.1562\n",
      "Epoch 39/34858\n",
      "32/32 [==============================] - 31s 976ms/step - loss: 6.8849 - acc: 0.1875\n",
      "Epoch 40/34858\n",
      "32/32 [==============================] - 31s 984ms/step - loss: 7.2636 - acc: 0.1562\n",
      "Epoch 41/34858\n",
      "32/32 [==============================] - 31s 975ms/step - loss: 7.0696 - acc: 0.1562\n",
      "\n",
      "==================Epoch 40=====================\n",
      "------------Diversity 0.7--------------\n",
      "晰晰燎火光，彻餗插皑渝。数方翠窟芜，阙来衍声阴。\n",
      "------------Diversity 1.0--------------\n",
      "东汉兴唐历，箑方兴茶莺。枋双聃凰辙，藿蕴欃荔四病\n",
      "------------Diversity 1.3--------------\n",
      "重轩深似谷，镬罢桩潮画毁眄去簏礴硬级么临刊跗诱色\n",
      "Epoch 42/34858\n",
      "32/32 [==============================] - 31s 982ms/step - loss: 6.9486 - acc: 0.1562\n",
      "Epoch 43/34858\n",
      "32/32 [==============================] - 31s 980ms/step - loss: 7.2795 - acc: 0.1250\n",
      "Epoch 44/34858\n",
      "32/32 [==============================] - 31s 978ms/step - loss: 7.0036 - acc: 0.1250\n",
      "Epoch 45/34858\n",
      "32/32 [==============================] - 31s 976ms/step - loss: 6.8045 - acc: 0.1875\n",
      "\n",
      "==================Epoch 44=====================\n",
      "------------Diversity 0.7--------------\n",
      "非君惜鸾殿，径出寞蟆色。不雾凉文僮，自灼梭销宫，\n",
      "------------Diversity 1.0--------------\n",
      "一食复何如，寄偲房日旄，汜戈借辘呦。格寄禊瘿觌，\n",
      "------------Diversity 1.3--------------\n",
      "泻月声不断，圭属复就响百姓禹危城喷步蜮弛烛朱珪砑\n",
      "Epoch 46/34858\n",
      "32/32 [==============================] - 31s 977ms/step - loss: 6.9364 - acc: 0.1562\n",
      "Epoch 47/34858\n",
      "32/32 [==============================] - 32s 987ms/step - loss: 7.0084 - acc: 0.1562\n",
      "Epoch 48/34858\n",
      "32/32 [==============================] - 31s 979ms/step - loss: 6.6629 - acc: 0.1875\n",
      "Epoch 49/34858\n",
      "32/32 [==============================] - 31s 977ms/step - loss: 6.6649 - acc: 0.1562\n",
      "\n",
      "==================Epoch 48=====================\n",
      "------------Diversity 0.7--------------\n",
      "九月三十日，乡驻古风辰。芳古笳山残，无峤粱，书。\n",
      "------------Diversity 1.0--------------\n",
      "咏歌有离合，峡飞倍珍去。隐独率圃仰，轻皓猿岳，。\n",
      "------------Diversity 1.3--------------\n",
      "否极生大贤，日帔摐祇腥。泻柄疮虚移防怡飘殁粒垫馈\n",
      "Epoch 50/34858\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32/32 [==============================] - 31s 976ms/step - loss: 6.9402 - acc: 0.1562\n",
      "Epoch 51/34858\n",
      "32/32 [==============================] - 31s 971ms/step - loss: 6.7699 - acc: 0.1875\n",
      "Epoch 52/34858\n",
      "32/32 [==============================] - 31s 972ms/step - loss: 6.9880 - acc: 0.1562\n",
      "Epoch 53/34858\n",
      "32/32 [==============================] - 319s 10s/step - loss: 6.5778 - acc: 0.1562\n",
      "\n",
      "==================Epoch 52=====================\n",
      "------------Diversity 0.7--------------\n",
      "寂寞吾庐贫，朝卧焉凉衣。邑笳淳抚澄，开锦伫文至。\n",
      "------------Diversity 1.0--------------\n",
      "宠至乃不惊，响匮苦综星。悦瑶逢旧售，持。蕣蝝啅。\n",
      "------------Diversity 1.3--------------\n",
      "城楼枕南浦，芹犀搴琥饮。树天栌临枝攫朣泚騄摘索游\n",
      "Epoch 54/34858\n",
      "32/32 [==============================] - 32s 1s/step - loss: 6.4135 - acc: 0.1875\n",
      "Epoch 55/34858\n",
      "16/32 [==============>...............] - ETA: 17s - loss: 6.8577 - acc: 0.1250"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-14-94a25cf5fb81>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mconfig\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mConfig\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mmodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mPoetryModel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mConfig\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m<ipython-input-13-7a27c9a57730>\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, config)\u001b[0m\n\u001b[0;32m     16\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mload_model\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight_file\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     17\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 18\u001b[1;33m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     19\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     20\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mbuild_model\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-13-7a27c9a57730>\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    184\u001b[0m             callbacks=[\n\u001b[0;32m    185\u001b[0m                 \u001b[0mkeras\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mModelCheckpoint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight_file\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msave_weights_only\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 186\u001b[1;33m                 \u001b[0mLambdaCallback\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mon_epoch_end\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgenerate_sample_result\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    187\u001b[0m             ]\n\u001b[0;32m    188\u001b[0m         )\n",
      "\u001b[1;32mf:\\pythonprojects\\lenv\\test\\lib\\site-packages\\keras\\legacy\\interfaces.py\u001b[0m in \u001b[0;36mwrapper\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m     89\u001b[0m                 warnings.warn('Update your `' + object_name +\n\u001b[0;32m     90\u001b[0m                               '` call to the Keras 2 API: ' + signature, stacklevel=2)\n\u001b[1;32m---> 91\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     92\u001b[0m         \u001b[0mwrapper\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_original_function\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     93\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\pythonprojects\\lenv\\test\\lib\\site-packages\\keras\\engine\\training.py\u001b[0m in \u001b[0;36mfit_generator\u001b[1;34m(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)\u001b[0m\n\u001b[0;32m   2242\u001b[0m                     outs = self.train_on_batch(x, y,\n\u001b[0;32m   2243\u001b[0m                                                \u001b[0msample_weight\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 2244\u001b[1;33m                                                class_weight=class_weight)\n\u001b[0m\u001b[0;32m   2245\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2246\u001b[0m                     \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mouts\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\pythonprojects\\lenv\\test\\lib\\site-packages\\keras\\engine\\training.py\u001b[0m in \u001b[0;36mtrain_on_batch\u001b[1;34m(self, x, y, sample_weight, class_weight)\u001b[0m\n\u001b[0;32m   1888\u001b[0m             \u001b[0mins\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0my\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0msample_weights\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1889\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_make_train_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1890\u001b[1;33m         \u001b[0moutputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mins\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1891\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1892\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0moutputs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\pythonprojects\\lenv\\test\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, inputs)\u001b[0m\n\u001b[0;32m   2473\u001b[0m         \u001b[0msession\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mget_session\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2474\u001b[0m         updated = session.run(fetches=fetches, feed_dict=feed_dict,\n\u001b[1;32m-> 2475\u001b[1;33m                               **self.session_kwargs)\n\u001b[0m\u001b[0;32m   2476\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0mupdated\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2477\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\pythonprojects\\lenv\\test\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m    903\u001b[0m     \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    904\u001b[0m       result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[1;32m--> 905\u001b[1;33m                          run_metadata_ptr)\n\u001b[0m\u001b[0;32m    906\u001b[0m       \u001b[1;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    907\u001b[0m         \u001b[0mproto_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\pythonprojects\\lenv\\test\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run\u001b[1;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m   1135\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[1;32mor\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1136\u001b[0m       results = self._do_run(handle, final_targets, final_fetches,\n\u001b[1;32m-> 1137\u001b[1;33m                              feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[0;32m   1138\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1139\u001b[0m       \u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\pythonprojects\\lenv\\test\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_run\u001b[1;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m   1353\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1354\u001b[0m       return self._do_call(_run_fn, self._session, feeds, fetches, targets,\n\u001b[1;32m-> 1355\u001b[1;33m                            options, run_metadata)\n\u001b[0m\u001b[0;32m   1356\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1357\u001b[0m       \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\pythonprojects\\lenv\\test\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_call\u001b[1;34m(self, fn, *args)\u001b[0m\n\u001b[0;32m   1359\u001b[0m   \u001b[1;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1360\u001b[0m     \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1361\u001b[1;33m       \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1362\u001b[0m     \u001b[1;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1363\u001b[0m       \u001b[0mmessage\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\pythonprojects\\lenv\\test\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[1;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[0;32m   1338\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1339\u001b[0m           return tf_session.TF_Run(session, options, feed_dict, fetch_list,\n\u001b[1;32m-> 1340\u001b[1;33m                                    target_list, status, run_metadata)\n\u001b[0m\u001b[0;32m   1341\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1342\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msession\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "from config import Config\n",
    "model = PoetryModel(Config)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "None\n"
     ]
    }
   ],
   "source": [
    "from config import Config\n",
    "\n",
    "# Building the model runs PoetryModel.__init__, which calls preprocess_file\n",
    "# and then train(): the load-weights branch is gated on self.loaded_model,\n",
    "# and __init__ sets that flag to False.\n",
    "model = PoetryModel(Config)\n",
    "\n",
    "# Text handed to predict2 (presumably the opening line that the generated\n",
    "# poem continues -- confirm against predict2).\n",
    "# NOTE(review): the saved output of this cell is 'None', so predict2\n",
    "# presumably returns nothing -- confirm against its definition.\n",
    "opening_line = '空山新雨后，'\n",
    "sentence = model.predict2(opening_line)\n",
    "print(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Temperature scaling of a toy distribution, done in log space:\n",
    "# p_i -> exp(log(p_i) / T), then renormalise so the result sums to 1.\n",
    "temperature = 0.8\n",
    "preds = np.asarray([0.1, 0.2, 0.3, 0.4], dtype='float64')\n",
    "preds = np.log(preds) / temperature\n",
    "exp_preds = np.exp(preds)\n",
    "preds = exp_preds / exp_preds.sum()\n",
    "print(preds)\n",
    "\n",
    "# Draw one index according to the rescaled probabilities. An integer\n",
    "# population is sampled as if it were np.arange(len(preds)).\n",
    "# (np.random.multinomial(1, preds, 1) is an alternative that returns a\n",
    "# one-hot draw instead of an index.)\n",
    "pro = np.random.choice(len(preds), size=1, p=preds)\n",
    "print(int(pro.squeeze()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Temperature rescaling written directly as a power: p_i ** (1 / T),\n",
    "# renormalised to sum to 1 (algebraically equal to exp(log(p_i) / T)).\n",
    "temperature = 0.8\n",
    "preds = np.asarray([0.1, 0.2, 0.3, 0.4], dtype='float64')\n",
    "exp_preds = np.power(preds, 1. / temperature)\n",
    "preds = exp_preds / exp_preds.sum()\n",
    "print(preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append five greeting lines to out/out.txt.\n",
    "# Create the target directory first: open() in append mode creates the file\n",
    "# but not missing parent directories, so on a fresh checkout this cell would\n",
    "# otherwise fail with FileNotFoundError. `os` comes from the notebook's\n",
    "# first (imports) cell.\n",
    "os.makedirs('out', exist_ok=True)\n",
    "with open('out/out.txt', 'a', encoding='utf-8') as f:\n",
    "    for i in range(5):\n",
    "        f.write('你好啊' + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
